LIME paper: Recurrent Neural Network for Solubility Prediction

Import packages and set up RNN

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, FancyBboxPatch
from matplotlib.offsetbox import AnnotationBbox
import seaborn as sns
import textwrap
import skunk
import matplotlib as mpl
import numpy as np
import tensorflow as tf
import selfies as sf
import exmol
from dataclasses import dataclass
from rdkit.Chem.Draw import rdDepictor, MolsToGridImage
from rdkit.Chem import MolFromSmiles, MACCSkeys

# Global RDKit/plotting configuration and data load. NOTE: order matters —
# the seaborn calls may reset matplotlib rcParams, so the explicit
# mpl.rcParams assignments must come after them.
rdDepictor.SetPreferCoordGen(True)
sns.set_context("notebook")
sns.set_style(
    "dark",
    {
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#666666",
        "ytick.color": "#666666",
        "axes.edgecolor": "#666666",
        "axes.linewidth": 0.8,
        "figure.dpi": 300,
    },
)
# NOTE(review): "#F06060" appears twice in the cycle — looks like a
# copy/paste slip; confirm against the intended palette.
color_cycle = ["#F06060", "#1BBC9B", "#F06060", "#5C4B51", "#F3B562", "#6e5687"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)
mpl.rcParams["font.size"] = 10
# Curated solubility dataset: SMILES, solubility labels, and descriptor columns.
soldata = pd.read_csv(
    "https://github.com/whitead/dmol-book/raw/master/data/curated-solubility-dataset.csv"
)
# Numeric descriptor columns start at "MolWt".
features_start_at = list(soldata.columns).index("MolWt")
# Seed NumPy so the sampling/subsampling below is reproducible.
np.random.seed(0)
2022-05-11 18:16:31.729801: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/hostedtoolcache/Python/3.8.12/x64/lib
2022-05-11 18:16:31.729836: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
# scramble them
# Keep a reproducible 1% random subsample and renumber rows from 0.
soldata = soldata.sample(frac=0.01, random_state=0).reset_index(drop=True)
soldata.head()
ID Name InChI InChIKey SMILES Solubility SD Ocurrences Group MolWt ... NumRotatableBonds NumValenceElectrons NumAromaticRings NumSaturatedRings NumAliphaticRings RingCount TPSA LabuteASA BalabanJ BertzCT
0 B-4206 diuron InChI=1S/C9H10Cl2N2O/c1-13(2)9(14)12-6-3-4-7(1... XMTQQYYKAHVGBJ-UHFFFAOYSA-N CN(C)C(=O)Nc1ccc(Cl)c(Cl)c1 -3.744300 1.227164 5 G4 233.098 ... 1.0 76.0 1.0 0.0 0.0 1.0 32.34 92.603980 2.781208 352.665233
1 F-988 7-(3-amino-3-methylazetidin-1-yl)-8-chloro-1-c... InChI=1S/C17H17ClFN3O3/c1-17(20)6-21(7-17)14-1... DUNZFXZSFJLIKR-UHFFFAOYSA-N CC1(N)CN(C2=C(Cl)C3=C(C=C2F)C(=O)C(C(=O)O)=CN3... -5.330000 0.000000 1 G1 365.792 ... 3.0 132.0 2.0 2.0 2.0 4.0 88.56 147.136366 2.001398 973.487509
2 C-1996 4-acetoxybiphenyl; 4-biphenylyl acetate InChI=1S/C14H12O2/c1-11(15)16-14-9-7-13(8-10-1... MISFQCBPASYYGV-UHFFFAOYSA-N CC(=O)OC1=CC=C(C=C1)C2=CC=CC=C2 -4.400000 0.000000 1 G1 212.248 ... 2.0 80.0 2.0 0.0 0.0 2.0 26.30 94.493449 2.228677 471.848345
3 A-3055 methane dimolybdenum InChI=1S/CH4.2Mo/h1H4;; JAGQSESDQXCFCH-UHFFFAOYSA-N C.[Mo].[Mo] -3.420275 0.409223 2 G3 207.923 ... 0.0 20.0 0.0 0.0 0.0 0.0 0.00 49.515427 -0.000000 2.754888
4 A-2575 ethyl 4-[[(methylphenylamino)methylene]amino]b... InChI=1S/C17H18N2O2/c1-3-21-17(20)14-9-11-15(1... GNGYPJUKIKDJQT-UHFFFAOYSA-N CCOC(=O)c1ccc(cc1)N=CN(C)c2ccccc2 -5.450777 0.000000 1 G1 282.343 ... 5.0 108.0 2.0 0.0 0.0 2.0 41.90 124.243431 2.028889 606.447052

5 rows × 26 columns

# Encode every SMILES as a SELFIES string; molecules that cannot be
# encoded are recorded as None so list positions stay aligned with soldata.
selfies_list = []
for smiles in soldata.SMILES:
    try:
        canonical = exmol.sanitize_smiles(smiles)[1]
        selfies_list.append(sf.encoder(canonical))
    except sf.EncoderError:
        selfies_list.append(None)
len(selfies_list)
100
# Vocabulary: index 0 is reserved for the [nop] padding token, followed by
# the union of the basic alphabet and every token seen in the data.
basic = set(exmol.get_basic_alphabet())
data_vocab = set(
    sf.get_alphabet_from_selfies([s for s in selfies_list if s is not None])
)
vocab = ["[nop]"] + list(data_vocab | basic)
vocab_stoi = {token: index for index, token in enumerate(vocab)}


def selfies2ints(s):
    """Convert a SELFIES string into a list of vocabulary indices.

    Fragment separators ('.') are dropped; tokens missing from the
    vocabulary map to NaN so downstream code can flag the molecule.
    """
    indices = []
    for token in sf.split_selfies(s):
        if token == ".":
            continue  # skip fragment separators
        indices.append(vocab_stoi.get(token, np.nan))
    return indices


def ints2selfies(v):
    """Invert selfies2ints: concatenate the vocabulary tokens for indices v."""
    return "".join(vocab[i] for i in v)


# test them out: SELFIES -> ints -> SELFIES should round-trip exactly,
# modulo '.' fragment separators, which selfies2ints drops.
s = selfies_list[0]
print("selfies:", s)
v = selfies2ints(s)
print("selfies2ints:", v)
so = ints2selfies(v)
print("ints2selfies:", so)  # fixed label typo (was "ints2selfes")
assert so == s.replace(
    ".", ""
)  # make sure '.' is removed from Selfies string during assertion
selfies: [C][N][Branch1][C][C][C][=Branch1][C][=O][N][C][=C][C][=C][Branch1][C][Cl][C][Branch1][C][Cl][=C][Ring1][Branch2]
selfies2ints: [30, 6, 17, 30, 30, 30, 25, 30, 22, 6, 30, 12, 30, 12, 17, 30, 13, 30, 17, 30, 13, 12, 33, 16]
ints2selfes: [C][N][Branch1][C][C][C][=Branch1][C][=O][N][C][=C][C][=C][Branch1][C][Cl][C][Branch1][C][Cl][=C][Ring1][Branch2]
# creating an object
@dataclass
class Config:
    """Hyperparameters for the solubility RNN and its data pipeline."""

    vocab_size: int  # number of SELFIES tokens, including the [nop] padding token
    example_number: int  # number of molecules in the dataset
    batch_size: int  # minibatch size for tf.data batching
    buffer_size: int  # shuffle buffer size for the training pipeline
    embedding_dim: int  # token embedding width
    rnn_units: int  # GRU hidden-state size
    hidden_dim: int  # width of the dense layer after the GRU


# Concrete hyperparameter values used throughout the rest of the notebook.
config = Config(
    vocab_size=len(vocab),
    example_number=len(selfies_list),
    batch_size=16,
    buffer_size=10000,
    embedding_dim=256,
    hidden_dim=128,
    rnn_units=128,
)
# now get sequences (skip molecules whose SELFIES encoding failed)
encoded = [selfies2ints(s) for s in selfies_list if s is not None]
padded_seqs = tf.keras.preprocessing.sequence.pad_sequences(encoded, padding="post")

# Now build dataset.
# Use the same `is not None` predicate as the `encoded` filter above;
# the previous bool(s) mask would also drop an empty-string SELFIES and
# silently misalign labels with sequences.
valid_mask = [s is not None for s in selfies_list]
data = tf.data.Dataset.from_tensor_slices(
    (padded_seqs, soldata.Solubility.iloc[valid_mask].values)
)
# now split into val, test, train and batch (10% test, 10% val, rest train)
N = len(data)
split = int(0.1 * N)
test_data = data.take(split).batch(config.batch_size)
nontest = data.skip(split)
val_data = nontest.take(split).batch(config.batch_size)
train_data = (
    nontest.skip(split)
    .shuffle(config.buffer_size)
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
2022-05-11 18:16:34.130745: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/hostedtoolcache/Python/3.8.12/x64/lib
2022-05-11 18:16:34.130783: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2022-05-11 18:16:34.130806: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (fv-az178-670): /proc/driver/nvidia/version does not exist
2022-05-11 18:16:34.131139: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
# Build the network with the list form of Sequential: a token embedding
# (index 0 treated as padding via mask_zero), a GRU encoder, a relu hidden
# layer, and a single linear output for regression.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.embedding_dim,
            mask_zero=True,  # 0 == [nop] padding
        ),
        tf.keras.layers.GRU(config.rnn_units),
        tf.keras.layers.Dense(config.hidden_dim, activation="relu"),
        # regression, so no activation on the output
        tf.keras.layers.Dense(1),
    ]
)

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 embedding (Embedding)       (None, None, 256)         12288     
                                                                 
 gru (GRU)                   (None, 128)               148224    
                                                                 
 dense (Dense)               (None, 128)               16512     
                                                                 
 dense_1 (Dense)             (None, 1)                 129       
                                                                 
=================================================================
Total params: 177,153
Trainable params: 177,153
Non-trainable params: 0
_________________________________________________________________
# Train with Adam (lr 1e-4) on mean-squared error.
model.compile(tf.optimizers.Adam(1e-4), loss="mean_squared_error")
# verbose=0 silences output, to get progress bar set verbose=1
result = model.fit(train_data, validation_data=val_data, epochs=100, verbose=0)
# Persist the trained model (TF SavedModel directory) for reuse.
model.save("solubility-rnn-accurate")
# model = tf.keras.models.load_model('solubility-rnn-accurate/')
2022-05-11 18:17:41.890399: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: solubility-rnn-accurate/assets
INFO:tensorflow:Assets written to: solubility-rnn-accurate/assets
WARNING:absl:<keras.layers.recurrent.GRUCell object at 0x7f836690c6d0> has the same name 'GRUCell' as a built-in Keras object. Consider renaming <class 'keras.layers.recurrent.GRUCell'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.
# Plot training/validation loss curves from the fit history.
plt.figure(figsize=(5, 3.5))
plt.plot(result.history["loss"], label="training")
plt.plot(result.history["val_loss"], label="validation")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.savefig("rnn-loss.png", bbox_inches="tight", dpi=300)
plt.show()
../_images/219b1fb0b2420eb66b3362548e391143d26c9433ef3c079f1dae7e4080d6f2e6.png
# Evaluate on the held-out test split: collect predictions and labels batch
# by batch, then parity-plot predicted vs. true solubility.
yhat = []
test_y = []
for x, y in test_data:
    yhat.extend(model(x).numpy().flatten())
    test_y.extend(y.numpy().flatten())
yhat = np.array(yhat)
test_y = np.array(test_y)

# plot test data
plt.figure(figsize=(5, 3.5))
# perfect-prediction reference line
plt.plot(test_y, test_y, ":")
plt.plot(test_y, yhat, ".")
plt.text(
    max(test_y) - 6,
    min(test_y) + 1,
    f"correlation = {np.corrcoef(test_y, yhat)[0,1]:.3f}",
)
# "loss" here is reported as RMSE, not the MSE used during training
plt.text(
    max(test_y) - 6, min(test_y), f"loss = {np.sqrt(np.mean((test_y - yhat)**2)):.3f}"
)
plt.xlabel(r"$y$")
plt.ylabel(r"$\hat{y}$")
plt.title("Testing Data")
plt.savefig("rnn-fit.png", dpi=300, bbox_inches="tight")
plt.show()
../_images/45d58487f2a9baa4c90240df31d4fbec2c8a687b80e92472be99e98b39459316.png

LIME explanations

In the following example, we find out which descriptors influence the solubility of a molecule. For example, let’s say we have a molecule with LogS=1.5. We create a perturbed chemical space around that molecule using the STONED method and then use LIME to find out which descriptors affect solubility predictions for that molecule.

Wrapper function for RNN, to use in STONED

# Predictor function is used as input to sample_space function
def predictor_function(smile_list, selfies):
    """Predict solubility for a batch of SELFIES; invalid encodings give NaN."""
    int_seqs = [selfies2ints(s) for s in selfies]
    # A NaN token poisons the sum, so `sum > 0` is False for invalid sequences
    validity = [1.0 if sum(seq) > 0 else np.nan for seq in int_seqs]
    cleaned = [np.nan_to_num(seq, nan=0) for seq in int_seqs]
    padded = tf.keras.preprocessing.sequence.pad_sequences(cleaned, padding="post")
    predictions = model.predict(padded).reshape(-1)
    return predictions * validity

Descriptor explanations

# Make sure SMILES doesn't contain multiple fragments
# Explain the first molecule of the (shuffled) dataset.
smi = soldata.SMILES[0]
# STONED sampling: 2500 perturbed molecules, at most one mutation each,
# drawn from the basic SELFIES alphabet.
stoned_kwargs = {
    "num_samples": 2500,
    "alphabet": exmol.get_basic_alphabet(),
    "max_mutations": 1,
}
space = exmol.sample_space(
    smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
)
from IPython.display import display, SVG

# Run LIME once per descriptor type and render the attribution plots.
desc_type = ["Classic", "ECFP", "MACCS"]

for d in desc_type:
    # NOTE: `beta` is rebound each iteration; after the loop it holds the
    # coefficients of the last descriptor type ("MACCS"), which the
    # weighted-fit figure further below relies on.
    beta = exmol.lime_explain(space, descriptor_type=d)
    if d == "Classic":
        exmol.plot_descriptors(space, d, output_file=f"{d}.svg")
    else:
        # ECFP/MACCS plots are returned as SVG text; render via skunk
        svg = exmol.plot_descriptors(space, d, output_file=f"{d}.svg")
        plt.close()
        skunk.display(svg)
SMARTS annotations for MACCS descriptors were created using SMARTSviewer (smartsview.zbh.uni-hamburg.de, Copyright: ZBH, Center for Bioinformatics Hamburg) developed by K. Schomburg et. al. (J. Chem. Inf. Model. 2010, 50, 9, 1529–1535)
../_images/463339fa87d35a1a06e7e29e274a6642db772eb47f74bbd6873ebcb555f85b39.png
fkw = {"figsize": (6, 4)}
# NOTE(review): `font` is defined but not visibly applied below
# (mpl.rc("font", ...) only sets size) — confirm whether it was meant to be used.
font = {"family": "normal", "weight": "normal", "size": 16}

fig = plt.figure(figsize=(10, 5))
mpl.rc("axes", titlesize=12)
mpl.rc("font", size=16)
# Mosaic layout: panel A (parity plot) on the left, wider panel B on the right.
ax_dict = fig.subplot_mosaic("AABBB")

# Plot space by fit
svg = exmol.plot_utils.plot_space_by_fit(
    space,
    [space[0]],
    figure_kwargs=fkw,
    mol_size=(200, 200),
    offset=1,
    ax=ax_dict["B"],
    beta=beta,
)
# Compute y_wls
# LIME-style weights: sharply decaying function of similarity; the 1e-6
# offset guards against division by zero when similarity == 0.
w = np.array([1 / (1 + (1 / (e.similarity + 0.000001) - 1) ** 5) for e in space])
non_zero = w > 10 ** (-6)
w = w[non_zero]
N = w.shape[0]

# Model predictions and descriptor matrix restricted to non-negligible weights.
ys = np.array([e.yhat for e in space])[non_zero].reshape(N).astype(float)
x_mat = np.array([list(e.descriptors.descriptors) for e in space])[non_zero].reshape(
    N, -1
)
# Surrogate prediction: linear model in descriptor space, centered on mean(ys).
y_wls = x_mat @ beta
y_wls += np.mean(ys)

lower = np.min(ys)
higher = np.max(ys)

# set transparency using w
norm = plt.Normalize(min(w), max(w))
cmap = plt.cm.Oranges(w)
cmap[:, -1] = w

def weighted_mean(x, w):
    """Weighted average of x with weights w."""
    return np.dot(x, w) / np.sum(w)


def weighted_cov(x, y, w):
    """Weighted covariance between x and y."""
    dx = x - weighted_mean(x, w)
    dy = y - weighted_mean(y, w)
    return np.dot(w, dx * dy) / np.sum(w)


def weighted_correlation(x, y, w):
    """Weighted Pearson correlation of x and y."""
    denom = np.sqrt(weighted_cov(x, x, w) * weighted_cov(y, y, w))
    return weighted_cov(x, y, w) / denom


corr = weighted_correlation(ys, y_wls, w)

# Panel A: parity reference line plus weighted scatter of surrogate (g)
# against model predictions (y-hat).
ax_dict["A"].plot(
    np.linspace(lower, higher, 100), np.linspace(lower, higher, 100), "--", linewidth=2
)
# NOTE(review): `cmap=` is handed the RGBA array rather than a colormap;
# since explicit colors are already given via `c=`, it looks redundant — confirm.
sc = ax_dict["A"].scatter(ys, y_wls, s=50, marker=".", c=cmap, cmap=cmap)
ax_dict["A"].text(max(ys) - 3, min(ys) + 1, f"weighted \ncorrelation = {corr:.3f}")
ax_dict["A"].set_xlabel(r"$\hat{y}$")
ax_dict["A"].set_ylabel(r"$g$")
ax_dict["A"].set_title("Weighted Least Squares Fit")
ax_dict["A"].set_xlim(lower, higher)
ax_dict["A"].set_ylim(lower, higher)
# Force a square panel regardless of data ranges.
ax_dict["A"].set_aspect(1.0 / ax_dict["A"].get_data_ratio(), adjustable="box")
sm = plt.cm.ScalarMappable(cmap=plt.cm.Oranges, norm=norm)
cbar = plt.colorbar(sm, orientation="horizontal", pad=0.15, ax=ax_dict["A"])
cbar.set_label("Chemical similarity")
plt.tight_layout()
plt.savefig("weighted_fit.svg", dpi=300, bbox_inches="tight", transparent=False)
../_images/8a701430264e358dd2e1595cb73a6c49e4b0fed7c1734dc149368271502ca6c3.png

Robustness to incomplete sampling

We first sample a reference chemical space, and then subsample smaller chemical spaces from this reference. Rank correlation is computed between important descriptors for the smaller subspaces and the reference space.

# Sample a big space
# Reference chemical space: 5000 samples, up to 2 mutations. Used as the
# ground truth for the rank-correlation robustness study below.
stoned_kwargs = {
    "num_samples": 5000,
    "alphabet": exmol.get_basic_alphabet(),
    "max_mutations": 2,
}
space = exmol.sample_space(
    smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
)
len(space)
3128
# get descriptor attributions for the reference space
exmol.lime_explain(space, "MACCS", return_beta=False)
# Assign feature ids for rank comparison.
# (Fixed a duplicated `features = features =` assignment.)
features = {
    name: idx for idx, name in enumerate(space[0].descriptors.descriptor_names)
}
# Reference importances: descriptor name -> t-statistic, NaNs dropped,
# sorted by absolute importance (most important first).
baseline_imp = {
    name: t
    for name, t in zip(
        space[0].descriptors.descriptor_names, space[0].descriptors.tstats
    )
    if not np.isnan(t)
}
baseline_imp = dict(
    sorted(baseline_imp.items(), key=lambda item: abs(item[1]), reverse=True)
)
# Feature ids of the reference ranking, in rank order.
baseline_set = [features[x] for x in baseline_imp.keys()]
# Get subsets and calculate lime importances - subsample - get rank correlation
from scipy.stats import spearmanr

plt.figure(figsize=(4, 3))
N = len(space)
# Subspace sizes to test: 500, 1500, 2500, ... up to the full space.
size = np.arange(500, N, 1000)
# Map subspace size -> list of rank correlations; the full space correlates
# perfectly with itself, hence the {N: 1} seed entry.
rank_corr = {N: 1}
for i, f in enumerate(size):
    # subsample space
    rank_corr[f] = []
    for _ in range(10):
        # subsample space of size f
        idx = np.random.choice(np.arange(N), size=f, replace=False)
        subspace = [space[i] for i in idx]
        # get desc attributions
        ss_beta = exmol.lime_explain(subspace, descriptor_type="MACCS")
        # Importance (t-statistic) per descriptor for the subsampled space.
        ss_imp = {
            a: b
            for a, b in zip(
                subspace[0].descriptors.descriptor_names, subspace[0].descriptors.tstats
            )
            if not np.isnan(b)
        }
        ss_imp = dict(
            sorted(ss_imp.items(), key=lambda item: abs(item[1]), reverse=True)
        )
        ss_set = [features[x] for x in ss_imp.keys()]
        # Get ranks for subsampled space and compare with reference
        ranks = {a: [b] for a, b in zip(baseline_set[:5], np.arange(1, 6))}
        for j, s in enumerate(ss_set):
            if s in ranks:
                ranks[s].append(j + 1)
        # compute rank correlation
        # NOTE(review): ranks[x][1] assumes every top-5 baseline feature also
        # appears in ss_set; if one is missing this raises IndexError — confirm
        # the MACCS t-stats always cover the same feature set.
        r = spearmanr(np.arange(1, 6), [ranks[x][1] for x in ranks])
        rank_corr[f].append(r.correlation)

    plt.scatter(f, np.mean(rank_corr[f]), color="#13254a", marker="o")

# Mark the full-space reference point.
plt.scatter(N, 1.0, color="red", marker="o")
plt.axvline(x=N, linestyle=":", color="red")
plt.xlabel("Size of chemical space")
plt.ylabel("Rank correlation")
plt.tight_layout()
plt.savefig("rank correlation.svg", dpi=300, bbox_inches="tight")
../_images/bf924d44286d1172d4527d22275ffd1ebcc02907c2039487d7b3556446efef96.png

Effect of mutation number, alphabet and size of chemical space

# Mutation sweep: how does the number of mutations affect the explanations?
desc_type = ["Classic"]
muts = [1, 2, 3]
for n_mut in muts:
    # Fix both min and max so every sample has exactly n_mut mutations.
    stoned_kwargs = {
        "num_samples": 2500,
        "alphabet": exmol.get_basic_alphabet(),
        "min_mutations": n_mut,
        "max_mutations": n_mut,
    }
    space = exmol.sample_space(
        smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
    )
    for desc in desc_type:
        exmol.lime_explain(space, descriptor_type=desc)
        exmol.plot_descriptors(
            space,
            desc,
            output_file=f"desc_{desc}_mut{n_mut}.svg",
            title=f"Mutations={n_mut}",
        )
../_images/bf206220e8dd438f6e70b7a405a9579c3eefdffd89456b82c56b0d2976a4eb5a.png ../_images/909e78b801de8068af54f9c5bb88e66d7f3bfadf6c6cb8268f3efb17b5b6fc4f.png ../_images/8cdc38c58de09abe31fc9692a5e668616b1ad3ac8ed78a9923bb3b53dd2d4403.png
# Alphabet sweep: compare explanations across three SELFIES alphabets.
basic = exmol.get_basic_alphabet()
train = sf.get_alphabet_from_selfies([s for s in selfies_list if s is not None])
wide = sf.get_semantic_robust_alphabet()
desc_type = ["MACCS"]
alphs = {"Basic": basic, "Training Data": train, "SELFIES": wide}
for a, alphabet in alphs.items():
    stoned_kwargs = {"num_samples": 2500, "alphabet": alphabet, "max_mutations": 2}
    space = exmol.sample_space(
        smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
    )
    for d in desc_type:
        exmol.lime_explain(space, descriptor_type=d)
        svg = exmol.plot_descriptors(
            space, d, output_file=f"desc_{d}_alph_{a}.svg", title=f"Alphabet: {a}"
        )
        plt.close()
        skunk.display(svg)
SMARTS annotations for MACCS descriptors were created using SMARTSviewer (smartsview.zbh.uni-hamburg.de, Copyright: ZBH, Center for Bioinformatics Hamburg) developed by K. Schomburg et. al. (J. Chem. Inf. Model. 2010, 50, 9, 1529–1535)
SMARTS annotations for MACCS descriptors were created using SMARTSviewer (smartsview.zbh.uni-hamburg.de, Copyright: ZBH, Center for Bioinformatics Hamburg) developed by K. Schomburg et. al. (J. Chem. Inf. Model. 2010, 50, 9, 1529–1535)
SMARTS annotations for MACCS descriptors were created using SMARTSviewer (smartsview.zbh.uni-hamburg.de, Copyright: ZBH, Center for Bioinformatics Hamburg) developed by K. Schomburg et. al. (J. Chem. Inf. Model. 2010, 50, 9, 1529–1535)
# Size of space: how does the number of sampled molecules affect explanations?
desc_type = ["MACCS"]
space_size = [1500, 2000, 2500]
for s in space_size:
    stoned_kwargs = {
        "num_samples": s,
        "alphabet": exmol.get_basic_alphabet(),
        "max_mutations": 2,
    }
    space = exmol.sample_space(
        smi, predictor_function, stoned_kwargs=stoned_kwargs, quiet=True
    )
    for d in desc_type:
        exmol.lime_explain(space, descriptor_type=d)
        svg = exmol.plot_descriptors(
            space,
            d,
            # BUG FIX: the filename previously interpolated the leftover loop
            # variable `a` from the alphabet sweep, so every space size wrote
            # the same file; key the filename on the size `s` instead.
            output_file=f"desc_{d}_size_{s}.svg",
            title=f"Chemical space size={s}",
        )
        plt.close()
        skunk.display(svg)
SMARTS annotations for MACCS descriptors were created using SMARTSviewer (smartsview.zbh.uni-hamburg.de, Copyright: ZBH, Center for Bioinformatics Hamburg) developed by K. Schomburg et. al. (J. Chem. Inf. Model. 2010, 50, 9, 1529–1535)
SMARTS annotations for MACCS descriptors were created using SMARTSviewer (smartsview.zbh.uni-hamburg.de, Copyright: ZBH, Center for Bioinformatics Hamburg) developed by K. Schomburg et. al. (J. Chem. Inf. Model. 2010, 50, 9, 1529–1535)
SMARTS annotations for MACCS descriptors were created using SMARTSviewer (smartsview.zbh.uni-hamburg.de, Copyright: ZBH, Center for Bioinformatics Hamburg) developed by K. Schomburg et. al. (J. Chem. Inf. Model. 2010, 50, 9, 1529–1535)